{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "premade.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.2"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1Z6Wtb_jisbA"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "QUyRGn9riopB",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "H1yCdGFW4j_F"
      },
      "source": [
        "# 预创建的 Estimators"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PS6_yKSoyLAl"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/estimator/premade\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />在 tensorFlow.google.cn 上查看</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/zh-cn/tutorials/estimator/premade.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />在 Google Colab 中运行</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/zh-cn/tutorials/estimator/premade.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />在 GitHub 上查看源代码</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/zh-cn/tutorials/estimator/premade.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />下载 notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FgdA9XE5ZCS3",
        "colab_type": "text"
      },
      "source": [
        "Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为， 所以无法保证它们是最准确的，并且反映了最新的\n",
        "[官方英文文档](https://www.tensorflow.org/?hl=en)。如果您有改进此翻译的建议， 请提交 pull request 到\n",
        "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文，请加入\n",
        "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "R4YZ_ievcY7p"
      },
      "source": [
        "\n",
        "本教程将向您展示如何使用 Estimators 解决 Tensorflow 中的鸢尾花（Iris）分类问题。Estimator 是 Tensorflow 完整模型的高级表示，它被设计用于轻松扩展和异步训练。更多细节请参阅 [Estimators](https://tensorflow.google.cn/guide/estimator)。\n",
        "\n",
        "请注意，在 Tensorflow 2.0 中，[Keras API](https://tensorflow.google.cn/guide/keras) 可以完成许多相同的任务，而且被认为是一个更易学习的API。如果您刚刚开始入门，我们建议您从 Keras 开始。有关 Tensorflow 2.0 中可用高级API的更多信息，请参阅 [Keras标准化](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a)。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8IFct0yedsTy"
      },
      "source": [
        "## 首先要做的事\n",
        "\n",
        "为了开始，您将首先导入 Tensorflow 和一系列您需要的库。\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "jPo5bQwndr9P",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "try:\n",
        "  # Colab only\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "    pass\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "import pandas as pd"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "c5w4m5gncnGh"
      },
      "source": [
        "## 数据集\n",
        "\n",
        "本文档中的示例程序构建并测试了一个模型，该模型根据[花萼](https://en.wikipedia.org/wiki/Sepal)和[花瓣](https://en.wikipedia.org/wiki/Petal)的大小将鸢尾花分成三种物种。\n",
        "\n",
        "您将使用鸢尾花数据集训练模型。该数据集包括四个特征和一个[标签](https://developers.google.com/machine-learning/glossary/#label)。这四个特征确定了单个鸢尾花的以下植物学特征：\n",
        "\n",
        "* 花萼长度\n",
        "* 花萼宽度\n",
        "* 花瓣长度\n",
        "* 花瓣宽度\n",
        "\n",
        "根据这些信息，您可以定义一些有用的常量来解析数据：\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "lSyrXp_He_UE",
        "colab": {}
      },
      "source": [
        "CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']\n",
        "SPECIES = ['Setosa', 'Versicolor', 'Virginica']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "j6mTfIQzfC9w"
      },
      "source": [
        "接下来，使用 Keras 与 Pandas 下载并解析鸢尾花数据集。注意为训练和测试保留不同的数据集。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "PumyCN8VdGGc",
        "colab": {}
      },
      "source": [
        "train_path = tf.keras.utils.get_file(\n",
        "    \"iris_training.csv\", \"https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv\")\n",
        "test_path = tf.keras.utils.get_file(\n",
        "    \"iris_test.csv\", \"https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv\")\n",
        "\n",
        "train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)\n",
        "test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wHFxNLszhQjz"
      },
      "source": [
        "通过检查数据您可以发现有四列浮点型特征和一列 int32 型标签。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WOJt-ML4hAwI",
        "colab": {}
      },
      "source": [
        "train.head()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jQJEYfVvfznP"
      },
      "source": [
        "对于每个数据集都分割出标签，模型将被训练来预测这些标签。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sSaJNGeaZCTG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_y = train.pop('Species')\n",
        "test_y = test.pop('Species')\n",
        "\n",
        "# 标签列现已从数据中删除\n",
        "train.head()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jZx1L_1Vcmxv"
      },
      "source": [
        "## Estimator 编程概述\n",
        "\n",
        "现在您已经设定好了数据，您可以使用 Tensorflow Estimator 定义模型。Estimator 是从 `tf.estimator.Estimator` 中派生的任何类。Tensorflow提供了一组`tf.estimator`(例如，`LinearRegressor`)来实现常见的机器学习算法。此外，您可以编写您自己的[自定义 Estimator](https://tensorflow.google.cn/guide/custom_estimators)。入门阶段我们建议使用预创建的 Estimator。\n",
        "\n",
        "为了编写基于预创建的 Estimator 的 Tensorflow 项目，您必须完成以下工作：\n",
        "\n",
        "* 创建一个或多个输入函数\n",
        "* 定义模型的特征列\n",
        "* 实例化一个 Estimator，指定特征列和各种超参数。\n",
        "* 在 Estimator 对象上调用一个或多个方法，传递合适的输入函数以作为数据源。\n",
        "\n",
        "我们来看看这些任务是如何在鸢尾花分类中实现的。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2OcguDfBcmmg"
      },
      "source": [
        "## 创建输入函数\n",
        "\n",
        "您必须创建输入函数来提供用于训练、评估和预测的数据。\n",
        "\n",
        "**输入函数**是一个返回 `tf.data.Dataset` 对象的函数，此对象会输出下列含两个元素的元组：\n",
        "\n",
        "* [`features`](https://developers.google.com/machine-learning/glossary/#feature)——Python字典，其中：\n",
        "    * 每个键都是特征名称\n",
        "    * 每个值都是包含此特征所有值的数组\n",
        "* `label` 包含每个样本的[标签](https://developers.google.com/machine-learning/glossary/#label)的值的数组。\n",
        "\n",
        "为了向您展示输入函数的格式，请查看下面这个简单的实现：\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "nzr5vRr5caGF",
        "colab": {}
      },
      "source": [
        "def input_evaluation_set():\n",
        "    features = {'SepalLength': np.array([6.4, 5.0]),\n",
        "                'SepalWidth':  np.array([2.8, 2.3]),\n",
        "                'PetalLength': np.array([5.6, 3.3]),\n",
        "                'PetalWidth':  np.array([2.2, 1.0])}\n",
        "    labels = np.array([2, 1])\n",
        "    return features, labels"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NpXvGjfnjHgY"
      },
      "source": [
        "\n",
        "您的输入函数可以以您喜欢的方式生成 `features` 字典与 `label` 列表。但是，我们建议使用 Tensorflow 的 [Dataset API](https://tensorflow.google.cn/guide/datasets)，该 API 可以用来解析各种类型的数据。\n",
        "\n",
        "Dataset API 可以为您处理很多常见情况。例如，使用 Dataset API，您可以轻松地从大量文件中并行读取记录，并将它们合并为单个数据流。\n",
        "\n",
        "为了简化此示例，我们将使用 [pandas](https://pandas.pydata.org/) 加载数据，并利用此内存数据构建输入管道。\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "T20u1anCi8NP",
        "colab": {}
      },
      "source": [
        "def input_fn(features, labels, training=True, batch_size=256):\n",
        "    \"\"\"An input function for training or evaluating\"\"\"\n",
        "    # 将输入转换为数据集。\n",
        "    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))\n",
        "\n",
        "    # 如果在训练模式下混淆并重复数据。\n",
        "    if training:\n",
        "        dataset = dataset.shuffle(1000).repeat()\n",
        "    \n",
        "    return dataset.batch(batch_size)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xIwcFT4MlZEi"
      },
      "source": [
        "## 定义特征列（feature columns）\n",
        "\n",
        "[**特征列（feature columns）**](https://developers.google.com/machine-learning/glossary/#feature_columns)是一个对象，用于描述模型应该如何使用特征字典中的原始输入数据。当您构建一个 Estimator 模型的时候，您会向其传递一个特征列的列表，其中包含您希望模型使用的每个特征。`tf.feature_column` 模块提供了许多为模型表示数据的选项。\n",
        "\n",
        "对于鸢尾花问题，4 个原始特征是数值，因此我们将构建一个特征列的列表，以告知 Estimator 模型将 4 个特征都表示为 32 位浮点值。故创建特征列的代码如下所示：\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ZTTriO8FlSML",
        "colab": {}
      },
      "source": [
        "# 特征列描述了如何使用输入。\n",
        "my_feature_columns = []\n",
        "for key in train.keys():\n",
        "    my_feature_columns.append(tf.feature_column.numeric_column(key=key))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jpKkhMoZljco"
      },
      "source": [
        "特征列可能比上述示例复杂得多。您可以从[指南](https://tensorflow.google.cn/guide/feature_columns)获取更多关于特征列的信息。\n",
        "\n",
        "我们已经介绍了如何使模型表示原始特征，现在您可以构建 Estimator 了。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kuE59XHEl22K"
      },
      "source": [
        "## 实例化 Estimator\n",
        "\n",
        "鸢尾花为题是一个经典的分类问题。幸运的是，Tensorflow 提供了几个预创建的 Estimator 分类器，其中包括：\n",
        "\n",
        "* `tf.estimator.DNNClassifier` 用于多类别分类的深度模型\n",
        "* `tf.estimator.DNNLinearCombinedClassifier` 用于广度与深度模型\n",
        "* `tf.estimator.LinearClassifier` 用于基于线性模型的分类器\n",
        "\n",
        "对于鸢尾花问题，`tf.estimator.LinearClassifier` 似乎是最好的选择。您可以这样实例化该 Estimator：\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "qnf4o2V5lcPn",
        "colab": {}
      },
      "source": [
        "# 构建一个拥有两个隐层，隐藏节点分别为 30 和 10 的深度神经网络。\n",
        "classifier = tf.estimator.DNNClassifier(\n",
        "    feature_columns=my_feature_columns,\n",
        "    # 隐层所含结点数量分别为 30 和 10.\n",
        "    hidden_units=[30, 10],\n",
        "    # 模型必须从三个类别中做出选择。\n",
        "    n_classes=3)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tzzt5nUpmEe3"
      },
      "source": [
        " ## 训练、评估和预测\n",
        "\n",
        "我们已经有一个 Estimator 对象，现在可以调用方法来执行下列操作：\n",
        "\n",
        "* 训练模型。\n",
        "* 评估经过训练的模型。\n",
        "* 使用经过训练的模型进行预测。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rnihuLdWmE75"
      },
      "source": [
        "### 训练模型\n",
        "\n",
        "通过调用 Estimator 的 `Train` 方法来训练模型，如下所示："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "4jW08YtPl1iS",
        "colab": {}
      },
      "source": [
        "# 训练模型。\n",
        "classifier.train(\n",
        "    input_fn=lambda: input_fn(train, train_y, training=True),\n",
        "    steps=5000)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ybiTFDmlmes8"
      },
      "source": [
        "注意将 ` input_fn` 调用封装在 [`lambda`](https://docs.python.org/3/tutorial/controlflow.html) 中以获取参数，同时提供不带参数的输入函数，如 Estimator 所预期的那样。`step` 参数告知该方法在训练多少步后停止训练。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HNvJLH8hmsdf"
      },
      "source": [
        "### 评估经过训练的模型\n",
        "\n",
        "现在模型已经经过训练，您可以获取一些关于模型性能的统计信息。代码块将在测试数据上对经过训练的模型的准确率（accuracy）进行评估：\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "A169XuO4mKxF",
        "colab": {}
      },
      "source": [
        "eval_result = classifier.evaluate(\n",
        "    input_fn=lambda: input_fn(test, test_y, training=False))\n",
        "\n",
        "print('\\nTest set accuracy: {accuracy:0.3f}\\n'.format(**eval_result))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VnPMP5EHph17"
      },
      "source": [
        "与对 `train` 方法的调用不同，我们没有传递 `steps` 参数来进行评估。用于评估的 `input_fn` 只生成一个 [epoch](https://developers.google.com/machine-learning/glossary/#epoch) 的数据。\n",
        "\n",
        "`eval_result` 字典亦包含 `average_loss`（每个样本的平均误差），`loss`（每个 mini-batch 的平均误差）与 Estimator 的 `global_step`（经历的训练迭代次数）值。 "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ur624ibpp52X"
      },
      "source": [
        "### 利用经过训练的模型进行预测（推理）\n",
        "\n",
        "我们已经有一个经过训练的模型，可以生成准确的评估结果。我们现在可以使用经过训练的模型，根据一些无标签测量结果预测鸢尾花的品种。与训练和评估一样，我们使用单个函数调用进行预测："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wltc0jpgng38",
        "colab": {}
      },
      "source": [
        "# 由模型生成预测\n",
        "expected = ['Setosa', 'Versicolor', 'Virginica']\n",
        "predict_x = {\n",
        "    'SepalLength': [5.1, 5.9, 6.9],\n",
        "    'SepalWidth': [3.3, 3.0, 3.1],\n",
        "    'PetalLength': [1.7, 4.2, 5.4],\n",
        "    'PetalWidth': [0.5, 1.5, 2.1],\n",
        "}\n",
        "\n",
        "def input_fn(features, batch_size=256):\n",
        "    \"\"\"An input function for prediction.\"\"\"\n",
        "    # 将输入转换为无标签数据集。\n",
        "    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)\n",
        "\n",
        "predictions = classifier.predict(\n",
        "    input_fn=lambda: input_fn(predict_x))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JsETKQo0rHvi"
      },
      "source": [
        "`predict` 方法返回一个 Python 可迭代对象，为每个样本生成一个预测结果字典。以下代码输出了一些预测及其概率："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Efm4mLzkrCxp",
        "colab": {}
      },
      "source": [
        "for pred_dict, expec in zip(predictions, expected):\n",
        "    class_id = pred_dict['class_ids'][0]\n",
        "    probability = pred_dict['probabilities'][class_id]\n",
        "\n",
        "    print('Prediction is \"{}\" ({:.1f}%), expected \"{}\"'.format(\n",
        "        SPECIES[class_id], 100 * probability, expec))"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}